import os
from data.dataset import load_cifar10
from models.resnet import ResNet18
from training.trainer import Trainer
from utils.config import config

def main() -> None:
    """Set up directories, load CIFAR-10, train a ResNet18 and print the result.

    Steps, in order:
      1. Create the directories named by ``config.DATA_DIR``,
         ``config.LOG_DIR`` and ``config.SAVE_DIR`` if they do not exist yet.
      2. Call ``load_cifar10()`` to obtain the train/val/test loaders and the
         class list.
      3. Instantiate ``ResNet18`` and hand it, together with the three
         loaders, to a ``Trainer``.
      4. Run ``trainer.train()`` and print its return value as the final
         test accuracy.
    """
    # Make sure the required directories exist.
    for dir_key in ("DATA_DIR", "LOG_DIR", "SAVE_DIR"):
        os.makedirs(getattr(config, dir_key), exist_ok=True)

    # Load the dataset.
    train_dl, val_dl, test_dl, classes = load_cifar10()
    print(f"Dataset loaded. Classes: {classes}")

    # Build the model.
    net = ResNet18()
    print(f"Model {config.MODEL_NAME} initialized.")

    # Build the trainer around the model and the three data loaders.
    trainer = Trainer(net, train_dl, val_dl, test_dl)

    # Kick off training and report the value it returns.
    print("Starting training...")
    test_acc = trainer.train()
    print(f"Final test accuracy: {test_acc:.4f}")

# Run the pipeline only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()